{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2019 Google LLC\n",
    "# \n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/GoogleCloudPlatform/keras-idiomatic-programmer/blob/master/notebooks/prestem_deconvolution.ipynb\">\n",
    "<img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "\n",
    "Due to the time to train the model, change the Colab runtime to use the GPU:\n",
    "\n",
    "`Runtime -> Change runtime type -> Hardware accelerator -> GPU`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add Deconvolution Pre-Stem to ResNet50\n",
    "\n",
    "### Background\n",
    "\n",
    "The ResNet50 architecture ([Deep Residual Learning for Image Recognition, 2015](https://arxiv.org/pdf/1707.00600.pdf)) does not learn well (or at all) with small image sizes, such as the CIFAR-10 and CIFAR-100 whose image size is 32x32. The reason is that the feature maps are downsampled too soon in the architecture and become 1x1 (single pixel) before reaching the bottleneck layer prior to the classifier.\n",
    "\n",
    "The ResNet50 was designed for 224x224 but will work well for size 128x128."
   ]
  },
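  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check of the downsampling argument above, the sketch below halves the spatial size five times, rounding up each time. This assumes the stock ResNet50 layout (a strided stem convolution, a max pooling layer, and three strided residual groups). The helper name `final_feature_map_size` is invented for this illustration and is not part of Keras. Under these assumptions a 32x32 input ends at 1x1, a 128x128 input at 4x4, and a 224x224 input at 7x7."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def final_feature_map_size(input_size, num_halvings=5):\n",
    "    # Each stride-2 stage roughly halves height and width, rounding up\n",
    "    size = input_size\n",
    "    for _ in range(num_halvings):\n",
    "        size = math.ceil(size / 2)\n",
    "    return size\n",
    "\n",
    "# Compare the CIFAR size with the two sizes the text says work well\n",
    "for input_size in (32, 128, 224):\n",
    "    print(input_size, '->', final_feature_map_size(input_size))"
   ]
  },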
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We start by importing the tf.keras modules we will use, along with numpy and the builtin dataset for CIFAR-10.\n",
    "\n",
    "This tutorial will work with both TF 1.X and TF 2.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports for the model\n",
    "from tensorflow.keras import Sequential, Model\n",
    "from tensorflow.keras.layers import Dense, Conv2DTranspose, BatchNormalization\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.applications import ResNet50\n",
    "\n",
    "# imports for the dataset\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solution\n",
    "\n",
    "We could upsample the CIFAR-10 images upstream from 32x32 to 128x128, using an interpolation algorithm such as BI-CUBIC --but this 'hardwired' interpolation may not be the best and may introduce artifacts. Additionally, being upstream from the model, it is generally an inefficient method.\n",
    "The ResNet50 was designed for 224x224 and work well for size 128x128, but not for small images such as 32x32.\n",
    "\n",
    "The authors addressed this in their second paper ([Identity Mappings in Deep Residual Networks](https://arxiv.org/pdf/1603.05027.pdf)) with the ResNet56 v2 architecture optimized for CIFAR-10/CIFAR-100. They did this by:\n",
    "\n",
    "    - Reducing the stem convolution from using a large 7x7 filter to a much smaller 3x3\n",
    "      filter that captured coarse details in a much smaller image size (32x32).\n",
    "      \n",
    "    - Changing the residual groups to reduce downsampling of the feature maps, such that\n",
    "      they continued to be 4x4 at the bottleneck layer, as was in the first paper for\n",
    "      ResNet50 v1 for the 224 x 224 images.\n",
    "\n",
    "*Drawback*\n",
    "\n",
    "The drawback with this 2016 approach was that one had to redesign the micro-architecture (meta-parameters) of an existing neural networks to accomodate image size that was substantially different from the original neural network.\n",
    "\n",
    "In recent years, the design of convolutional neural networks have become increasingly more formalized. In addition to a stem group we see the emergence of **pre-stem** groups which can learned transformations of the input to *dynamically* adjust to an existing architecture that was designed for a different image size.\n",
    "\n",
    "*Alternate Solution - Upscale Image Upstream from Model*\n",
    "\n",
    "One could updsample the CIFAR-10 images upstream from 32x32 to 128x128, using an interpolation algorithm such as BI-CUBIC --but this 'hardwired' interpolation may not be the best and may introduce artifacts. Additionally, being upstream from the model, it is generally an inefficient method.\n",
    "\n",
    "### Pre-Stem Solution\n",
    "\n",
    "Instead, one can add a *pre-stem* Group at the bottom (input) layer of a stock ResNet50 to learn the best upsampling using deconvolution (also known as a transpose convolution). \n",
    "\n",
    "Additionally, the pre-stem becomes part of the model graph."
   ]
  },
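  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the upstream alternative described above could be sketched as follows. This is a minimal illustration that assumes TF 2.x, where `tf.image.resize` accepts `method='bicubic'`. The random batch only stands in for CIFAR-10 images, and the names `batch` and `upsampled` are chosen for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Stand-in for a small batch of normalized 32x32 RGB images (assumption: random data)\n",
    "batch = np.random.rand(4, 32, 32, 3).astype(np.float32)\n",
    "\n",
    "# Fixed (non-learned) bicubic interpolation from 32x32 to 128x128\n",
    "upsampled = tf.image.resize(batch, (128, 128), method='bicubic')\n",
    "print(upsampled.shape)"
   ]
  },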
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1\n",
    "\n",
    "We start with a stock `ResNet50` without a classifier and reset the input shape to (128, 128, 3), which we name as the `base` model.\n",
    "\n",
    "Next, we add the classifier layer as a Dense layer of 10 nodes, which we name as the `resnet` model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a pre-built ResNet50 w/o the top layer (classifer) and input shape configured for 128 x 128\n",
    "base = ResNet50(include_top=False, input_shape=(128, 128, 3), pooling='max')\n",
    "\n",
    "# Add a new classifier (top) layer for the 10 classes in CIFAR-10\n",
    "outputs = Dense(10, activation='softmax')(base.output)\n",
    "\n",
    "# Rebuild the model with the new classifier\n",
    "resnet = Model(base.input, outputs)\n",
    "resnet.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2\n",
    "\n",
    "We construct a pre-stem group using two deconvolutions (also called a transpose convolution):\n",
    "\n",
    "    1. First deconvolution takes (32, 32, 3) as input and upsamples to (64, 64, 3).\n",
    "    2. Second deconvolution upsamples to (128, 128, 3)\n",
    "    3. We use the add() method to attach the pre-stem to the resnet model.\n",
    "    \n",
    "Essentially, the pre-stem takes the (32, 32, 3) CIFAR-10 inputs and outputs (128, 128, 3) which is then the input to the resnet model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the pre-stem as a Sequential model\n",
    "model = Sequential()\n",
    "\n",
    "# This is the first deconvolution, which takes the (32, 32, 3) CIFAR-10 input and outputs (64, 64, 3)\n",
    "model.add(Conv2DTranspose(3, (3, 3), strides=2, padding='same', activation='relu', input_shape=(32,32,3)))\n",
    "model.add(BatchNormalization())\n",
    "\n",
    "# This is the second deconvolution which outputs (128, 128, 3) which matches the input to our ResNet50 model\n",
    "model.add(Conv2DTranspose(3, (3, 3), strides=2, padding='same', activation='relu'))\n",
    "model.add(BatchNormalization())\n",
    "\n",
    "# Add the ResNet50 model as the remaining layers and rebuild\n",
    "model.add(resnet)\n",
    "model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.001), metrics=['acc'])\n",
    "\n",
    "model.summary()"
   ]
  },
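  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm the shapes stated above, one can build the two transposed convolutions on their own and inspect the output shape. This is only a sketch: `prestem_only` is a throwaway name, and the activation and batch normalization layers are left out because they do not change the shape. With `padding='same'`, each stride-2 layer doubles the height and width, so the printed shape should be `(None, 128, 128, 3)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import Sequential\n",
    "from tensorflow.keras.layers import Conv2DTranspose\n",
    "\n",
    "# Only the two upsampling layers of the pre-stem, without ResNet50 attached\n",
    "prestem_only = Sequential()\n",
    "prestem_only.add(Conv2DTranspose(3, (3, 3), strides=2, padding='same', input_shape=(32, 32, 3)))\n",
    "prestem_only.add(Conv2DTranspose(3, (3, 3), strides=2, padding='same'))\n",
    "\n",
    "# The batch dimension is reported as None\n",
    "print(prestem_only.output_shape)"
   ]
  },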
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the Data\n",
    "\n",
    "We will use the CIFAR-10 builtin dataset and normalize the image data and one-hot encode the labels upstream from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "x_train = (x_train / 255.0).astype(np.float32)\n",
    "x_test  = (x_test  / 255.0).astype(np.float32)\n",
    "\n",
    "y_train = to_categorical(y_train)\n",
    "y_test  = to_categorical(y_test)\n",
    "\n",
    "print(x_train.shape)"
   ]
  },
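  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of what the label encoding does, the sketch below one-hot encodes a toy label array with plain numpy. It is not how `to_categorical` is implemented; `toy_labels` and `one_hot` are names invented for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy labels shaped like the CIFAR-10 labels: (num_samples, 1)\n",
    "toy_labels = np.array([[3], [0], [9]])\n",
    "\n",
    "# Index the rows of a 10x10 identity matrix to get one row per label\n",
    "one_hot = np.eye(10, dtype=np.float32)[toy_labels.reshape(-1)]\n",
    "\n",
    "print(one_hot.shape)\n",
    "# Taking the argmax of each row recovers the original class index\n",
    "print(one_hot.argmax(axis=1))"
   ]
  },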
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the Model\n",
    "\n",
    "Let's partially train the model to demonstrate how a pre-stem works. First, for ResNet50 I find a reliable choice of optimizer and learning rate is the Adam optimizer with a learning rate = 0.001. While the batch normalization should provide the ability for higher learning rates, I find with higher ones on ResNet50 it per epoch loss does not converge.\n",
    "\n",
    "\n",
    "We will then use the fit() method for a small number of epochs (5) and set aside 10% of the training data for the per epoch validation data.\n",
    "\n",
    "From my test run, I got:\n",
    "\n",
    "    Epoch 1: 27.7%\n",
    "    Epoch 2: 33.8%\n",
    "    Epoch 3: 42.9%\n",
    "    Epoch 4: 35.9%  -- dropped into a less good local minima\n",
    "    Epoch 5: 49.1%  -- found a better local minima to dive into"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=1, validation_split=0.1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
